package main

import (
	"fmt"
	"go/ast"
	"go/printer"
	"go/token"
	"io"
	"os"
	"sort"
	"strings"
)

// A module is a collection of COM declarations gathered from one or more
// parsed Go files (via loadFile), together with the state needed to print
// the expanded code for them (via write).
type module struct {
	// packageName is the name of the package that these declarations belong to.
	// It is taken from the first file loaded; every later file must match it.
	packageName string

	// fileSet is the FileSet representing the text that the declarations were
	// parsed from. It is handed to the printer whenever AST nodes are written.
	fileSet *token.FileSet

	// dllFuncs is the functions imported from DLLs, organized by DLL name.
	// Each entry is a body-less method declaration whose unnamed receiver
	// type names the DLL, e.g. "func (kernel32) GetTickCount() uint32".
	dllFuncs map[string][]*ast.FuncDecl

	// interfaces is the COM interfaces declared in the module, keyed by the
	// name of each iface.
	interfaces map[string]*iface

	// imports holds the import specs to emit in the generated file, keyed by
	// import path; write prints them in sorted key order.
	imports map[string]*ast.ImportSpec

	// miscDecls holds declarations that will be printed unchanged, in the
	// order they were loaded.
	miscDecls []ast.Decl

	// printConfig controls how AST nodes, such as the import specs and
	// miscDecls, are formatted on output.
	printConfig printer.Config
}

// newModule creates an empty module for declarations parsed with fileSet.
//
// The imports, dllFuncs and interfaces maps are allocated here so that
// loadFile and addImport can insert into them straight away; packageName
// stays empty until the first file is loaded. AST nodes are printed with
// tab indentation and space alignment.
func newModule(fileSet *token.FileSet) *module {
	m := new(module)
	m.fileSet = fileSet
	m.imports = map[string]*ast.ImportSpec{}
	m.dllFuncs = map[string][]*ast.FuncDecl{}
	m.interfaces = map[string]*iface{}
	m.printConfig.Mode = printer.UseSpaces | printer.TabIndent
	return m
}

// write prints m's expanded declarations to w as a complete Go source file.
//
// The output is, in order: the package clause; a comment recording the
// command line that generated the file; the import block, sorted by key;
// a "var _ unsafe.Pointer" line; the lazily loaded DLLs and their
// procedures, followed by the wrapper for each DLL function; the COM
// interfaces, sorted by name; and finally the miscellaneous declarations
// in the order they were loaded.
//
// DLLs and interfaces are visited in sorted order rather than map order, so
// generating from the same input always produces byte-identical output.
//
// As a side effect, write adds the imports the generated code relies on
// ("unsafe", plus the com package unless m is the com package itself). It
// also calls calcVTStart for every interface before writing any of them
// (presumably to lay out the vtables).
//
// It returns the first error from writing to w, printing an AST node, or
// one of the helper methods.
func (m *module) write(w io.Writer) error {
	// werr holds the first error returned by w. Once it is set, printf
	// becomes a no-op and the next checkpoint returns it. Without this,
	// failed writes of the fixed text below would go unnoticed.
	var werr error
	printf := func(format string, args ...interface{}) {
		if werr == nil {
			_, werr = fmt.Fprintf(w, format, args...)
		}
	}

	printf("package %s\n\n", m.packageName)

	if m.packageName != "com" {
		m.addImport("code.google.com/p/com-and-go")
	}
	m.addImport("unsafe")

	printf("// generated by %s\n\n", strings.Join(os.Args, " "))

	// m.imports is never empty here because "unsafe" was just added.
	paths := make([]string, 0, len(m.imports))
	for p := range m.imports {
		paths = append(paths, p)
	}
	sort.Strings(paths)

	printf("import (\n")
	for _, p := range paths {
		printf("\t")
		if werr != nil {
			return werr
		}
		if err := m.printConfig.Fprint(w, m.fileSet, m.imports[p]); err != nil {
			return err
		}
		printf("\n")
	}
	printf(")\n\n")

	// Referencing unsafe keeps the import above used (an unused import is a
	// compile error) even when no generated code needs it.
	printf("var _ unsafe.Pointer\n\n")

	if len(m.dllFuncs) > 0 {
		dlls := make([]string, 0, len(m.dllFuncs))
		for dll := range m.dllFuncs {
			dlls = append(dlls, dll)
		}
		sort.Strings(dlls)

		printf("var (\n")
		for _, dll := range dlls {
			printf("\tmod%s = syscall.NewLazyDLL(\"%s.dll\")\n", dll, dll)
		}
		printf("\n")
		for _, dll := range dlls {
			for _, f := range m.dllFuncs[dll] {
				printf("\tproc%s = mod%s.NewProc(%q)\n", f.Name.Name, dll, f.Name.Name)
			}
		}
		printf(")\n\n")

		for _, dll := range dlls {
			for _, f := range m.dllFuncs[dll] {
				if werr != nil {
					return werr
				}
				if err := m.writeFunc(w, f); err != nil {
					return err
				}
			}
		}
	}

	// Sorting before calcVTStart also makes the reported error (if any)
	// deterministic; the original map order was just one arbitrary order.
	ifNames := make([]string, 0, len(m.interfaces))
	for name := range m.interfaces {
		ifNames = append(ifNames, name)
	}
	sort.Strings(ifNames)
	for _, name := range ifNames {
		if err := m.calcVTStart(name, 0); err != nil {
			return err
		}
	}

	for _, name := range ifNames {
		if werr != nil {
			return werr
		}
		if err := m.writeInterface(w, m.interfaces[name]); err != nil {
			return err
		}
	}

	for _, decl := range m.miscDecls {
		if werr != nil {
			return werr
		}
		if err := m.printConfig.Fprint(w, m.fileSet, decl); err != nil {
			return err
		}
		printf("\n")
	}

	return werr
}

// loadFile loads the declarations from f into m.
//
// The first file loaded determines m's package name; a file from a
// different package is rejected. f's imports are recorded in m.imports,
// keyed by their path literal as written. Top-level declarations are then
// sorted into m:
//
//   - A method with no body whose receiver is unnamed and is a plain
//     identifier, such as "func (kernel32) GetTickCount() uint32", is a
//     DLL import. It is appended to m.dllFuncs under the receiver name,
//     and "syscall" is added to the imports.
//   - Import declarations are skipped; they were handled via f.Imports.
//   - Each spec of a type declaration is converted by newIface and stored
//     in m.interfaces. A spec without its own doc comment inherits the
//     declaration's.
//   - Everything else is appended unchanged to m.miscDecls.
//
// It returns an error on a package name mismatch or if newIface fails.
func (m *module) loadFile(f *ast.File) error {
	switch {
	case m.packageName == "":
		m.packageName = f.Name.Name
	case m.packageName != f.Name.Name:
		return fmt.Errorf("mismatched package names (%s and %s)", m.packageName, f.Name.Name)
	}

	for _, imp := range f.Imports {
		m.imports[imp.Path.Value] = imp
	}

	for _, decl := range f.Decls {
		switch d := decl.(type) {
		case *ast.FuncDecl:
			// A receiver list can be present but empty ("func () f()").
			// Current go/parser versions accept that and leave the rejection
			// to the type checker, so check the length before indexing.
			if d.Body == nil && d.Recv != nil && len(d.Recv.List) > 0 {
				recv := d.Recv.List[0]
				if dll, ok := recv.Type.(*ast.Ident); ok && recv.Names == nil {
					m.dllFuncs[dll.Name] = append(m.dllFuncs[dll.Name], d)
					m.addImport("syscall")
					continue
				}
			}

		case *ast.GenDecl:
			switch d.Tok {
			case token.IMPORT:
				// Already recorded from f.Imports above.
				continue

			case token.TYPE:
				for _, spec := range d.Specs {
					ts := spec.(*ast.TypeSpec)
					// For an ungrouped "type X ..." declaration, go/ast
					// attaches the doc comment to the GenDecl, not the spec.
					if ts.Doc == nil {
						ts.Doc = d.Doc
					}
					ifc, err := newIface(ts)
					if err != nil {
						return err
					}
					m.interfaces[ifc.name] = ifc
				}
				continue
			}
		}
		m.miscDecls = append(m.miscDecls, decl)
	}

	return nil
}

// addImport adds pkg, an unquoted import path such as "syscall", to m's
// list of imports.
//
// loadFile keys the imports it reads by their quoted path literal
// (ast.BasicLit.Value), so the same form is used here. Keying by the bare
// path would let a package that the input already imports, and that is
// also requested here, appear twice in the output, which does not compile.
// An existing entry is left untouched, so an import that the input declares
// (possibly under a local name) is not replaced by an unnamed one.
func (m *module) addImport(pkg string) {
	key := fmt.Sprintf("%q", pkg)
	if _, ok := m.imports[key]; ok {
		return
	}
	m.imports[key] = &ast.ImportSpec{
		Path: &ast.BasicLit{
			Kind:  token.STRING,
			Value: key,
		},
	}
}
